{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Matplotlib created a temporary config/cache directory at /tmp/matplotlib-q__yfjpz because the default path (/home/mikyas/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.image as mpimg \n",
    "import trt_pose.coco\n",
    "import math\n",
    "import os\n",
    "import numpy as np\n",
    "import traitlets\n",
    "import pickle \n",
    "import pyautogui\n",
    "import time "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "with open('hand_pose.json', 'r') as f:\n",
    "    hand_pose = json.load(f)\n",
    "\n",
    "topology = trt_pose.coco.coco_category_to_topology(hand_pose)\n",
    "import trt_pose.models\n",
    "\n",
    "num_parts = len(hand_pose['keypoints'])\n",
    "num_links = len(hand_pose['skeleton'])\n",
    "\n",
    "model = trt_pose.models.resnet18_baseline_att(num_parts, 2 * num_links).cuda().eval()\n",
    "import torch\n",
    "\n",
    "\n",
    "WIDTH = 224\n",
    "HEIGHT = 224\n",
    "data = torch.zeros((1, 3, HEIGHT, WIDTH)).cuda()\n",
    "\n",
    "if not os.path.exists('hand_pose_resnet18_att_244_244_trt.pth'):\n",
    "    MODEL_WEIGHTS = 'hand_pose_resnet18_att_244_244.pth'\n",
    "    model.load_state_dict(torch.load(MODEL_WEIGHTS))\n",
    "    import torch2trt\n",
    "    model_trt = torch2trt.torch2trt(model, [data], fp16_mode=True, max_workspace_size=1<<25)\n",
    "    OPTIMIZED_MODEL = 'hand_pose_resnet18_att_244_244_trt.pth'\n",
    "    torch.save(model_trt.state_dict(), OPTIMIZED_MODEL)\n",
    "\n",
    "\n",
    "OPTIMIZED_MODEL = 'hand_pose_resnet18_att_244_244_trt.pth'\n",
    "from torch2trt import TRTModule\n",
    "\n",
    "model_trt = TRTModule()\n",
    "model_trt.load_state_dict(torch.load(OPTIMIZED_MODEL))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from trt_pose.draw_objects import DrawObjects\n",
    "from trt_pose.parse_objects import ParseObjects\n",
    "\n",
    "parse_objects = ParseObjects(topology,cmap_threshold=0.15, link_threshold=0.15)\n",
    "draw_objects = DrawObjects(topology)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import torchvision.transforms as transforms\n",
    "import PIL.Image\n",
    "\n",
    "mean = torch.Tensor([0.485, 0.456, 0.406]).cuda()\n",
    "std = torch.Tensor([0.229, 0.224, 0.225]).cuda()\n",
    "device = torch.device('cuda')\n",
    "\n",
    "def preprocess(image):\n",
    "    global device\n",
    "    device = torch.device('cuda')\n",
    "    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
    "    image = PIL.Image.fromarray(image)\n",
    "    image = transforms.functional.to_tensor(image).to(device)\n",
    "    image.sub_(mean[:, None, None]).div_(std[:, None, None])\n",
    "    return image[None, ...]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.svm import SVC\n",
    "clf = make_pipeline(StandardScaler(), SVC(gamma='auto', kernel='rbf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from preprocessdata import preprocessdata\n",
    "preprocessdata = preprocessdata(topology, num_parts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "svm_train = False\n",
    "if svm_train:\n",
    "    clf, predicted = preprocessdata.trainsvm(clf, joints_train, joints_test, labels_train, hand.labels_test)\n",
    "    filename = 'svmmodel.sav'\n",
    "    pickle.dump(clf, open(filename, 'wb'))\n",
    "else:\n",
    "    filename = 'svmmodel.sav'\n",
    "    clf = pickle.load(open(filename, 'rb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jetcam.usb_camera import USBCamera\n",
    "from jetcam.csi_camera import CSICamera\n",
    "from jetcam.utils import bgr8_to_jpeg\n",
    "\n",
    "camera = USBCamera(width=WIDTH, height=HEIGHT, capture_fps=30, capture_device=1)\n",
    "#camera = CSICamera(width=WIDTH, height=HEIGHT, capture_fps=30)\n",
    "\n",
    "camera.running = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ae138f658a94450a9f45d6f81506aac0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Image(value=b'', format='jpeg', height='256', width='256')"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import ipywidgets\n",
    "from IPython.display import display\n",
    "\n",
    "\n",
    "image_w = ipywidgets.Image(format='jpeg', width=224, height=224)\n",
    "display(image_w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_joints(image, joints):\n",
    "    count = 0\n",
    "    for i in joints:\n",
    "        if i==[0,0]:\n",
    "            count+=1\n",
    "    if count>= 19:\n",
    "        return \n",
    "    for i in joints:\n",
    "        cv2.circle(image, (i[0],i[1]), 2, (0,0,255), 1)\n",
    "    cv2.circle(image, (joints[0][0],joints[0][1]), 2, (255,0,255), 1)\n",
    "    for i in hand_pose['skeleton']:\n",
    "        if joints[i[0]-1][0]==0 or joints[i[1]-1][0] == 0:\n",
    "            break\n",
    "        cv2.line(image, (joints[i[0]-1][0],joints[i[0]-1][1]), (joints[i[1]-1][0],joints[i[1]-1][1]), (0,255,0), 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "screenWidth, screenHeight = pyautogui.size()\n",
    "p_text = 'none'\n",
    "p_sc = 0\n",
    "cur_x, cur_y = pyautogui.position()\n",
    "fixed_x, fixed_y = pyautogui.position()\n",
    "pyautogui.FAILSAFE = False\n",
    "t0 = time.time()\n",
    "def control_cursor(text, joints):\n",
    "    global p_text\n",
    "    global p_sc\n",
    "    global t0\n",
    "    global cur_x\n",
    "    global cur_y\n",
    "    global fixed_x,  fixed_y\n",
    "    cursor_joint = 6\n",
    "    if p_text!=\"pan\":\n",
    "        #pyautogui.position()\n",
    "        fixed_x = joints[cursor_joint][0]\n",
    "        fixed_y = joints[cursor_joint][1] \n",
    "    if p_text!=\"click\" and text==\"click\":\n",
    "        pyautogui.mouseUp(((joints[cursor_joint][0])*screenWidth)/256, ((joints[cursor_joint][1])*screenHeight)/256, button= 'left')\n",
    "        pyautogui.click()\n",
    "    if text == \"pan\":\n",
    "        if joints[cursor_joint]!=[0,0]:\n",
    "            pyautogui.mouseUp(((joints[cursor_joint][0])*screenWidth)/256, ((joints[cursor_joint][1])*screenHeight)/256, button= 'left')\n",
    "\n",
    "            pyautogui.moveTo(((joints[cursor_joint][0])*screenWidth)/256, ((joints[cursor_joint][1])*screenHeight)/256)\n",
    "    if text == \"scroll\":\n",
    "        \n",
    "        if joints[cursor_joint]!=[0,0] and joints[0]!=[0,0]:\n",
    "            pyautogui.mouseUp(((joints[cursor_joint][0])*screenWidth)/256, ((joints[cursor_joint][1])*screenHeight)/256, button= 'left')#to_scroll = (joints[8][1]-joints[0][1])/10\n",
    "            to_scroll = (p_sc-joints[cursor_joint][1])\n",
    "            if to_scroll>0:\n",
    "                to_scroll = 1\n",
    "            else:\n",
    "                to_scroll = -1\n",
    "            pyautogui.scroll(int(to_scroll),x=(joints[cursor_joint][0]*screenWidth)/256, y=(joints[cursor_joint][1]*screenHeight)/256)\n",
    "    if text == \"zoom\":\n",
    "        \n",
    "        \n",
    "        pyautogui.keyDown('ctrl')\n",
    "        if joints[cursor_joint]!=[0,0] and joints[0]!=[0,0]:\n",
    "            pyautogui.mouseUp(((joints[cursor_joint][0])*screenWidth)/256, ((joints[cursor_joint][1])*screenHeight)/256, button= 'left')\n",
    "            \n",
    "            to_scroll = (p_sc-joints[cursor_joint][1])\n",
    "            if to_scroll>0:\n",
    "                to_scroll = 1\n",
    "            else:\n",
    "                to_scroll = -1\n",
    "            t1 = time.time()\n",
    "            #print(t1-t0)\n",
    "            if t1-t0>1:   \n",
    "                pyautogui.scroll(int(to_scroll),x=(joints[cursor_joint][0]*screenWidth)/256, y=(joints[cursor_joint][1]*screenHeight)/256)\n",
    "                t0 = time.time()\n",
    "        pyautogui.keyUp('ctrl')\n",
    "        \n",
    "        \n",
    "    if text == \"drag\":\n",
    "        \n",
    "        if joints[cursor_joint]!=[0,0]:\n",
    "            pyautogui.mouseDown(((joints[cursor_joint][0])*screenWidth)/256, ((joints[cursor_joint][1])*screenHeight)/256, button= 'left')\n",
    "        \n",
    "        \n",
    "    p_text = text\n",
    "    p_sc = joints[cursor_joint][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def execute(change):\n",
    "    image = change['new']\n",
    "    data = preprocess(image)\n",
    "    cmap, paf = model_trt(data)\n",
    "    cmap, paf = cmap.detach().cpu(), paf.detach().cpu()\n",
    "    counts, objects, peaks = parse_objects(cmap, paf)#, cmap_threshold=0.15, link_threshold=0.15)\n",
    "    #draw_objects(image, counts, objects, peaks)\n",
    "    joints = preprocessdata.joints_inference(image, counts, objects, peaks)\n",
    "    \n",
    "    dist_bn_joints = preprocessdata.find_distance(joints)\n",
    "    gesture = clf.predict([dist_bn_joints,[0]*num_parts*num_parts])\n",
    "    gesture_joints = gesture[0]\n",
    "    preprocessdata.prev_queue.append(gesture_joints)\n",
    "    preprocessdata.prev_queue.pop(0)\n",
    "    preprocessdata.print_label(image, preprocessdata.prev_queue)\n",
    "    #draw_joints(image, joints)\n",
    "    control_cursor(preprocessdata.text, joints)\n",
    "    image_w.value = bgr8_to_jpeg(image)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "execute({'new': camera.value})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "camera.observe(execute, names='value')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "camera.unobserve_all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#camera.running = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
